/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <Columns/ColumnNullable.h>
#include <Core/Block.h>
#include <Core/Field.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Processors/Chunk.h>
#include <Processors/Executors/PushingPipelineExecutor.h>
#include <Processors/ISimpleTransform.h>
#include <Processors/Sinks/SinkToStorage.h>
#include <Storages/NativeOutputWriter.h>
#include <Storages/Output/OutputFormatFile.h>
#include <Storages/PartitionedSink.h>
#include <base/types.h>
#include <Common/ArenaUtils.h>
#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
#include <Common/FieldAccurateComparison.h>

namespace local_engine
{

class NormalFileWriter : public NativeOutputWriter
{
public:
    static std::unique_ptr<NativeOutputWriter> create(
        const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & preferred_schema, const std::string & format_hint);

    NormalFileWriter(const OutputFormatFilePtr & file_, const DB::ContextPtr & context_);
    ~NormalFileWriter() override = default;

    void write(const DB::Block & block) override;
    void close() override;

private:
    DB::Block castBlock(const DB::Block & block) const;

    OutputFormatFilePtr file;
    DB::ContextPtr context;

    OutputFormatFile::OutputFormatPtr output_format;
    std::unique_ptr<DB::QueryPipeline> pipeline;
    std::unique_ptr<DB::PushingPipelineExecutor> writer;
};

OutputFormatFilePtr createOutputFormatFile(
    const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & preferred_schema, const std::string & format_hint);

struct DeltaStats
{
    size_t row_count;
    // TODO Support delta.dataSkippingStatsColumns, detail see https://docs.databricks.com/aws/en/delta/data-skipping
    size_t n_stats_cols;
    std::vector<DB::Field> min;
    std::vector<DB::Field> max;
    std::vector<Int64> null_count;
    std::set<size_t> partition_index;

    static DeltaStats create(const DB::Block & output, const DB::Names & partition)
    {
        size_t size = output.columns() - partition.size();
        std::set<size_t> partition_index;
        std::ranges::transform(
            partition,
            std::inserter(partition_index, partition_index.end()),
            [&](const auto & name) { return output.getPositionByName(name); });
        assert(partition_index.size() == partition.size());
        return DeltaStats(size, partition_index);
    }
    static DB::Block statsHeader(const DB::Block & output, const DB::Names & partition, DB::ColumnsWithTypeAndName && statsHeaderBase)
    {
        std::set<std::string> partition_index;
        std::ranges::transform(partition, std::inserter(partition_index, partition_index.end()), [&](const auto & name) { return name; });

        assert(partition_index.size() == partition.size());

        size_t num_stats_cols = numStatsCols(output.columns() - partition.size());
        auto appendBase = [&](const std::string & prefix)
        {
            for (size_t i = 0, n = 0; i < output.columns() && n < num_stats_cols; i++)
            {
                const auto & column = output.getByPosition(i);
                if (!partition_index.contains(column.name))
                {
                    statsHeaderBase.emplace_back(wrapNullableType(column.type), prefix + column.name);
                    ++n;
                }
            }
        };
        appendBase("min_");
        appendBase("max_");
        for (size_t i = 0, n = 0; i < output.columns() && n < num_stats_cols; i++)
        {
            const auto & column = output.getByPosition(i);
            if (!partition_index.contains(column.name))
            {
                statsHeaderBase.emplace_back(BIGINT(), "null_count_" + column.name);
                ++n;
            }
        }

        return DB::Block{statsHeaderBase};
    }

    static size_t numStatsCols(size_t origin)
    {
        if (DB::CurrentThread::isInitialized())
        {
            const DB::ContextPtr query_context = DB::CurrentThread::get().getQueryContext();
            if (query_context)
            {
                SparkSQLConfig config = SparkSQLConfig::loadFromContext(query_context);
                return std::min(config.deltaDataSkippingNumIndexedCols, origin);
            }
        }
        return origin;
    }

    explicit DeltaStats(size_t size, const std::set<size_t> & partition_index_ = {})
        : row_count(0)
        , n_stats_cols(numStatsCols(size))
        , min(n_stats_cols)
        , max(n_stats_cols)
        , null_count(n_stats_cols, 0)
        , partition_index(partition_index_)
    {
        assert(size > 0);
    }

    bool initialized() const { return row_count > 0; }

    void update(const DB::Chunk & chunk)
    {
        assert(chunk.getNumRows() > 0);
        const auto & columns = chunk.getColumns();
        assert(columns.size() - partition_index.size() >= n_stats_cols);
        for (size_t i = 0, col = 0; i < n_stats_cols && col < columns.size(); ++col)
        {
            if (partition_index.contains(col))
                continue;

            const auto & column = columns[col];
            Int64 null_count = 0;
            if (const auto * nullable_column = typeid_cast<const DB::ColumnNullable *>(column.get()))
            {
                const auto & null_map = nullable_column->getNullMapData();
                null_count = std::ranges::count_if(null_map, [](UInt8 value) { return value != 0; });
            }
            this->null_count[i] += null_count;

            DB::Field min_value, max_value;
            column->getExtremes(min_value, max_value);

            assert(min[i].isNull() || min_value.getType() == min[i].getType());
            assert(max[i].isNull() || max_value.getType() == max[i].getType());

            if (!initialized())
            {
                min[i] = min_value;
                max[i] = max_value;
            }
            else
            {
                min[i] = accurateLess(min[i], min_value) ? min[i] : min_value;
                max[i] = accurateLess(max[i], max_value) ? max_value : max[i];
            }
            ++i;
        }

        row_count += chunk.getNumRows();
    }

    void merge(const DeltaStats & right)
    {
        assert(n_stats_cols == right.n_stats_cols);
        assert(partition_index == right.partition_index);

        for (size_t i = 0; i < n_stats_cols; ++i)
        {
            null_count[i] += right.null_count[i];
            min[i] = std::min(min[i], right.min[i]);
            max[i] = std::max(max[i], right.max[i]);
        }
    }
};

class WriteStatsBase : public DB::ISimpleTransform
{
public:
    /// visible for UTs
    static const std::string NO_PARTITION_ID;

protected:
    bool all_chunks_processed_ = false; /// flag to determine if we have already processed all chunks
    virtual DB::Chunk final_result() = 0;

public:
    WriteStatsBase(const DB::SharedHeader & input_header_, const DB::SharedHeader & output_header_)
        : ISimpleTransform(input_header_, output_header_, true)
    {
    }

    Status prepare() override
    {
        if (input.isFinished() && !output.isFinished() && !has_input && !all_chunks_processed_)
        {
            all_chunks_processed_ = true;
            /// return Ready to call transform() for generating filling rows after latest chunk was processed
            return Status::Ready;
        }

        return ISimpleTransform::prepare();
    }

    void transform(DB::Chunk & chunk) override
    {
        if (all_chunks_processed_)
            chunk = final_result();
        else
            chunk = {};
    }
};

class WriteStats : public WriteStatsBase
{
    DB::MutableColumns columns_;

    enum ColumnIndex
    {
        filename,
        partition_id,
        record_count,
        stats_column_start = record_count + 1
    };
    static DB::ColumnsWithTypeAndName statsHeaderBase()
    {
        return {{STRING(), "filename"}, {STRING(), "partition_id"}, {BIGINT(), "record_count"}};
    }

protected:
    DB::Chunk final_result() override
    {
        size_t rows = columns_[filename]->size();
        return DB::Chunk(std::move(columns_), rows);
    }

public:
    WriteStats(const DB::SharedHeader & input_header_, const DB::SharedHeader & output_header_)
        : WriteStatsBase(input_header_, output_header_), columns_(output_header_->cloneEmptyColumns())
    {
    }
    static std::shared_ptr<WriteStats> create(const DB::SharedHeader & input, const DB::Names & partition)
    {
        return std::make_shared<WriteStats>(input, toShared(DeltaStats::statsHeader(*input, partition, statsHeaderBase())));
    }

    String getName() const override { return "WriteStats"; }

    void collectStats(const String & filename, const String & partition_dir, const DeltaStats & stats) const
    {
        const std::string & partition = partition_dir.empty() ? WriteStatsBase::NO_PARTITION_ID : partition_dir;
        size_t columnSize = stats.n_stats_cols;
        assert(columns_.size() == stats_column_start + columnSize * 3);

        columns_[ColumnIndex::filename]->insertData(filename.c_str(), filename.size());
        columns_[partition_id]->insertData(partition.c_str(), partition.size());
        auto & countColData = static_cast<DB::ColumnVector<Int64> &>(*columns_[record_count]).getData();
        countColData.emplace_back(stats.row_count);

        for (int i = 0; i < columnSize; ++i)
        {
            size_t offset = stats_column_start + i;
            columns_[offset]->insert(stats.min[i]);
            columns_[columnSize + offset]->insert(stats.max[i]);
            auto & nullCountData = static_cast<DB::ColumnVector<Int64> &>(*columns_[(columnSize * 2) + offset]).getData();
            nullCountData.emplace_back(stats.null_count[i]);
        }
    }
};

struct FileNameGenerator
{
    // Align with org.apache.spark.sql.execution.FileNamePlaceHolder
    static const std::vector<std::string> SUPPORT_PLACEHOLDERS;
    // Align with placeholders above
    const std::vector<bool> need_to_replace;
    const std::string file_pattern;

    FileNameGenerator(const std::string & file_pattern) : file_pattern(file_pattern), need_to_replace(compute_need_to_replace(file_pattern))
    {
    }

    std::vector<bool> compute_need_to_replace(const std::string & file_pattern)
    {
        std::vector<bool> result;
        for (const std::string & placeholder : SUPPORT_PLACEHOLDERS)
            if (file_pattern.find(placeholder) != std::string::npos)
                result.push_back(true);
            else
                result.push_back(false);
        return result;
    }

    std::string generate(const std::string & bucket = "") const
    {
        std::string result = file_pattern;
        if (need_to_replace[0]) // {id}
            result = pattern_format(SUPPORT_PLACEHOLDERS[0], toString(DB::UUIDHelpers::generateV4()));
        if (need_to_replace[1]) // {bucket}
            result = pattern_format(SUPPORT_PLACEHOLDERS[1], bucket);
        return result;
    }

    std::string pattern_format(const std::string & arg, const std::string & replacement) const
    {
        std::string format_str = file_pattern;
        size_t pos = format_str.find(arg);
        while (pos != std::string::npos)
        {
            format_str.replace(pos, arg.length(), replacement);
            pos = format_str.find(arg, pos + arg.length());
        }
        return format_str;
    }
};

class SubstraitFileSink final : public DB::SinkToStorage
{
    const std::string partition_id_;
    const bool bucketed_write_;
    const std::string relative_path_;
    OutputFormatFilePtr format_file_;
    OutputFormatFile::OutputFormatPtr output_format_;
    std::shared_ptr<WriteStats> stats_;
    DeltaStats delta_stats_;

    static std::string makeAbsoluteFilename(const std::string & base_path, const std::string & partition_id, const std::string & relative)
    {
        if (partition_id.empty())
            return fmt::format("{}/{}", base_path, relative);
        return fmt::format("{}/{}/{}", base_path, partition_id, relative);
    }

public:
    explicit SubstraitFileSink(
        const DB::ContextPtr & context,
        const std::string & base_path,
        const std::string & partition_id,
        const bool bucketed_write,
        const std::string & relative,
        const std::string & format_hint,
        const DB::SharedHeader header,
        const std::shared_ptr<WriteStatsBase> & stats,
        const DeltaStats & delta_stats)
        : SinkToStorage(header)
        , partition_id_(partition_id)
        , bucketed_write_(bucketed_write)
        , relative_path_(relative)
        , format_file_(createOutputFormatFile(context, makeAbsoluteFilename(base_path, partition_id, relative), *header, format_hint))
        , stats_(std::dynamic_pointer_cast<WriteStats>(stats))
        , delta_stats_(delta_stats)
    {
    }

    String getName() const override { return "SubstraitFileSink"; }

protected:
    void consume(DB::Chunk & chunk) override
    {
        delta_stats_.update(chunk);
        if (!output_format_) [[unlikely]]
            output_format_ = format_file_->createOutputFormat();

        const DB::Block & input_header = getHeader();
        if (bucketed_write_)
        {
            chunk.erase(input_header.columns() - 1);
            const DB::ColumnsWithTypeAndName & cols = input_header.getColumnsWithTypeAndName();
            DB::ColumnsWithTypeAndName without_bucket_cols(cols.begin(), cols.end() - 1);
            DB::Block without_bucket_header = DB::Block(without_bucket_cols);
            output_format_->output->write(materializeBlock(without_bucket_header.cloneWithColumns(chunk.detachColumns())));
        }
        else
            output_format_->output->write(materializeBlock(input_header.cloneWithColumns(chunk.detachColumns())));
    }
    void onFinish() override
    {
        if (output_format_)
        {
            output_format_->finalizeOutput();
            /// We need close reset output_format_ here before return to spark, because the file is closed in ~WriteBufferFromHDFSImpl().
            /// So that Spark Commit protocol can move the file safely.
            output_format_.reset();
            assert(delta_stats_.row_count > 0);
            if (stats_)
                stats_->collectStats(relative_path_, partition_id_, delta_stats_);
        }
    }
    void onCancel() noexcept override
    {
        if (output_format_)
        {
            output_format_->cancel();
            output_format_.reset();
        }
    }
};

class SparkPartitionedBaseSink : public DB::PartitionedSink
{
public:
    static const std::string DEFAULT_PARTITION_NAME;
    static const std::string BUCKET_COLUMN_NAME;

    static bool isBucketedWrite(const DB::Block & input_header)
    {
        return input_header.has(BUCKET_COLUMN_NAME) && input_header.getPositionByName(BUCKET_COLUMN_NAME) == input_header.columns() - 1;
    }

    /// visible for UTs
    static DB::ASTPtr make_partition_expression(const DB::Names & partition_columns, const DB::Block & input_header)
    {
        /// Parse the following expression into ASTs
        /// cancat('/col_name=', 'toString(col_name)')
        bool add_slash = false;
        DB::ASTs arguments;
        for (const auto & column : partition_columns)
        {
            // partition_column=
            auto column_name = std::make_shared<DB::ASTLiteral>(column);
            auto escaped_name = makeASTFunction("sparkPartitionEscape", DB::ASTs{column_name});
            if (add_slash)
                arguments.emplace_back(std::make_shared<DB::ASTLiteral>("/"));
            add_slash = true;
            arguments.emplace_back(escaped_name);
            arguments.emplace_back(std::make_shared<DB::ASTLiteral>("="));

            // ifNull(toString(partition_column), DEFAULT_PARTITION_NAME)
            // FIXME if toString(partition_column) is empty
            auto column_ast = makeASTFunction("toString", DB::ASTs{std::make_shared<DB::ASTIdentifier>(column)});
            auto escaped_value = makeASTFunction("sparkPartitionEscape", DB::ASTs{column_ast});
            DB::ASTs if_null_args{
                makeASTFunction("toString", DB::ASTs{escaped_value}), std::make_shared<DB::ASTLiteral>(DEFAULT_PARTITION_NAME)};
            arguments.emplace_back(makeASTFunction("ifNull", std::move(if_null_args)));
        }

        if (isBucketedWrite(input_header))
        {
            DB::ASTs args{std::make_shared<DB::ASTLiteral>("%05d"), std::make_shared<DB::ASTIdentifier>(BUCKET_COLUMN_NAME)};
            arguments.emplace_back(DB::makeASTFunction("printf", std::move(args)));
        }
        assert(!arguments.empty());
        if (arguments.size() == 1)
            return arguments[0];
        return DB::makeASTFunction("concat", std::move(arguments));
    }

private:
    static std::shared_ptr<DB::IPartitionStrategy>
    make_partition_strategy(const DB::ContextPtr & context, const DB::Names & partition_columns, const DB::Block & input_header)
    {
        DB::ASTPtr partition_by = make_partition_expression(partition_columns, input_header);
        return DB::PartitionStrategyFactory::get(
            DB::PartitionStrategyFactory::StrategyType::WILDCARD,
            partition_by,
            input_header.getNamesAndTypesList(),
            context,
            "", // format_name => no need
            false, // globbed_path  => no need
            true,
            true);
    }
    DB::SinkPtr createSinkForPartition(const String & partition_id) override
    {
        if (bucketed_write_)
        {
            std::string bucket_val = partition_id.substr(partition_id.length() - 5, 5);
            std::string real_partition_id = partition_id.substr(0, partition_id.length() - 5);
            return createSinkForPartition(real_partition_id, bucket_val);
        }
        return createSinkForPartition(partition_id, "");
    }

    virtual DB::SinkPtr createSinkForPartition(const String & partition_id, const String & bucket) = 0;

protected:
    DB::ContextPtr context_;
    std::shared_ptr<WriteStatsBase> stats_;
    DeltaStats empty_delta_stats_;
    bool bucketed_write_;

public:
    SparkPartitionedBaseSink(
        const DB::ContextPtr & context,
        const DB::Names & partition_by,
        const DB::SharedHeader & input_header,
        const std::shared_ptr<WriteStatsBase> & stats)
        : PartitionedSink(make_partition_strategy(context, partition_by, *input_header), context, input_header)
        , context_(context)
        , stats_(stats)
        , empty_delta_stats_(DeltaStats::create(*input_header, partition_by))
        , bucketed_write_(isBucketedWrite(*input_header))
    {
    }
};

class SubstraitPartitionedFileSink final : public SparkPartitionedBaseSink
{
    const std::string base_path_;
    const FileNameGenerator generator_;
    const DB::SharedHeader input_header_;
    const DB::SharedHeader sample_block_;
    const std::string format_hint_;

public:
    SubstraitPartitionedFileSink(
        const DB::ContextPtr & context,
        const DB::Names & partition_by,
        const DB::SharedHeader & input_header,
        const DB::SharedHeader & sample_block,
        const std::string & base_path,
        const FileNameGenerator & generator,
        const std::string & format_hint,
        const std::shared_ptr<WriteStatsBase> & stats)
        : SparkPartitionedBaseSink(context, partition_by, input_header, stats)
        , base_path_(base_path)
        , generator_(generator)
        , sample_block_(sample_block)
        , input_header_(input_header)
        , format_hint_(format_hint)
    {
    }

    DB::SinkPtr createSinkForPartition(const String & partition_id, const String & bucket) override
    {
        assert(stats_);
        bool bucketed_write = !bucket.empty();
        std::string filename = bucketed_write ? generator_.generate(bucket) : generator_.generate();
        const auto partition_path = fmt::format("{}/{}", partition_id, filename);
        validatePartitionKey(partition_path, true);
        return std::make_shared<SubstraitFileSink>(
            context_, base_path_, partition_id, bucketed_write, filename, format_hint_, sample_block_, stats_, empty_delta_stats_);
    }
    String getName() const override { return "SubstraitPartitionedFileSink"; }
};
}
